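# File:     JOP_RR1_financial_topics.R
# Purpose:  Builds multi-turn exchanges between the Fed chair and members of
#           Congress, pairs exchanges from different chairs, and asks
#           gpt-3.5-turbo which conversation in each pair is more/less
#           aggressive. Ends with an exploratory 1-10 "conflict" scoring prompt.
# Input:    ./data/finalData.RData
# Output:   ./output/chatGPT_polite_BTM_aggression.RData
# Author:   JB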


rm(list = ls())
# library() stops immediately if a package is missing; require() only warns.
library(tidyverse)
library(ggridges)
library(openai)

# Project root on the original author's machine. Only switch to it when it
# exists, so the replication package can also be run from its own folder.
proj_dir <- 'C:/Users/Jimbo/Dropbox/FED/FED/Paper/JOP/RR1_replication/'
if (dir.exists(proj_dir)) {
  setwd(proj_dir)
}

load('./data/finalData.RData')

# Choose whose OpenAI account to use: first command-line argument ('JC' for
# Josh, anything else for Jim), defaulting to 'JB' when no argument is given
# (e.g. when run interactively).
# API keys are read from environment variables and must never be committed
# to source. NOTE(review): the keys previously hard-coded here are exposed in
# the file's history and should be revoked.
cli_args <- commandArgs(trailingOnly = TRUE)
API <- if (length(cli_args) > 0) cli_args[[1]] else 'JB'

if (API == 'JC') {
  api_key <- Sys.getenv('OPENAI_API_KEY_JC')
  cat('Josh\n')
} else {
  api_key <- Sys.getenv('OPENAI_API_KEY_JB')
  cat('Jim\n')
}
# The openai calls below read OPENAI_API_KEY; fall back to an existing value.
if (nzchar(api_key)) {
  Sys.setenv(OPENAI_API_KEY = api_key)
}
if (!nzchar(Sys.getenv('OPENAI_API_KEY'))) {
  stop('No OpenAI API key found: set OPENAI_API_KEY_JC / OPENAI_API_KEY_JB ',
       '(or OPENAI_API_KEY) before running this script.', call. = FALSE)
}


create_prompt <- function(chunk, direction = 'more') {
  # Build a two-message chat prompt comparing two conversations.
  #
  # chunk:     text containing "Conversation 1:" and "Conversation 2:".
  # direction: 'more' or 'less'; which extreme of aggressiveness to ask about.
  # Returns a list of a system message (definition of an aggressive style) and
  # a user message (the question followed by the conversations).
  #
  # NOTE: the definition text contains typos ("than involves", "boundaires").
  # It is kept verbatim so the prompts match those used to generate the
  # saved results.
  definition <- list(
    "role" = "system",
    "content" = "An aggressive communication style is a way of communicating with others than involves assertiveness, dominance, bluntness, verbal attacks, ignoring boundaires, hostility, lack of empathy, manipulation, and defensiveness."
  )
  question <- list(
    "role" = "user",
    "content" = stringr::str_c(
      'Please read the following conversations between a chair of the Federal Reserve and a member of Congress. Out of the two examples, which conversation is ',
      direction,
      ' aggressive overall? Within the selected conversation, which speaker is more aggressive?\n\n',
      chunk
    )
  )
  list(definition, question)
}

clean_result <- function(result, direction = 'more', trait = 'aggressive') {
  # Build a follow-up prompt that asks the model to reduce its own free-text
  # answer to just the number of the chosen conversation.
  #
  # result:    text returned for a create_prompt() question.
  # direction: 'more' or 'less'; should match the direction used there.
  # trait:     the quality being compared. Defaults to 'aggressive' to match
  #            create_prompt(). This used to be hard-coded as 'polite', which
  #            asked a different question from the one the model had answered.
  # Returns a one-element list holding a single user message.
  list(
    list(
      "role" = "user",
      "content" = stringr::str_c(
        result,
        '\n\nIn the above text, which conversation is described as ',
        direction, ' ', trait,
        '? Return only the number of the conversation.'
      )
    )
  )
}

submit_openai <- function(prompt, temperature = 0.2, n = 1) {
  # Send one chat-completion request to gpt-3.5-turbo and return the response.
  #
  # prompt:      list of role/content messages (e.g. from create_prompt()).
  # temperature: sampling temperature passed through to the API.
  # n:           number of completions requested.
  completion <- openai::create_chat_completion(
    model = "gpt-3.5-turbo",
    messages = prompt,
    temperature = temperature,
    n = n
  )
  # Pause one second after every successful call as crude rate limiting.
  Sys.sleep(1)
  completion
}


# Build `toSample`: one row per multi-turn exchange ("chunk") in a hearing,
# with its utterances collapsed into "Speaker: text" lines. The chunks are
# later paired up and compared by the model. (Original note: "Can we do this
# pairwise?")
toSample <- utterance_level %>%
  # filter(docID == 'fed2001-02-13.txt') %>%
  arrange(docID,ind) %>%
  select(docID,chamber,date,speaker,opensecretsID,ind,nchars,textclean) %>%
  # filter(nchars < 1000) %>%
  group_by(docID) %>%
  # firstFED = index of the first utterance whose opensecretsID contains "FED"
  # in each document; everything before it is dropped.
  # NOTE(review): a document with no such speaker gets min() of all-NA =
  # Inf (with a warning), so all of its rows are dropped.
  mutate(firstFED = ifelse(grepl("FED",opensecretsID),ind,NA)) %>%
  mutate(firstFED = min(firstFED,na.rm=T)) %>%
  # select(docID,ind,firstFED,speaker) %>%
  filter(ind >= firstFED) %>%
  arrange(docID,ind) %>%
  select(docID,chamber,date,ind,speaker,textclean,opensecretsID) %>%
  # Keep only utterances whose index is exactly one more than the previous
  # row's (still grouped by docID). The first row of each document has
  # delta = NA and is also dropped. NOTE(review): that row is the first FED
  # utterance; confirm this is intended.
  mutate(delta = ind - lag(ind)) %>%
  filter(delta == 1) %>%
  # delta2 is only used by the commented-out chunking rule below.
  mutate(delta2 = ind - lag(ind)) %>%
  mutate(delta2 = ifelse(is.na(delta2),1,delta2)) %>%
  # mutate(chunkIndicator = cumsum(delta2 != 1)) %>%
  # Chunk ids:
  # - The counter increments whenever the speaker differs from the speaker two
  #   rows earlier, i.e. when a two-person back-and-forth is broken.
  # - Adjustment 1: a row whose id is one less than the next row's takes the
  #   next row's id.
  # - Adjustment 2: a row whose id is one more than the previous row's and one
  #   less than the next row's takes the previous row's id.
  # NOTE(review): lead() returns NA on the last row of each document, so the
  # ifelse() gives that row chunkIndicator = NA. It therefore ends up outside
  # its chunk (edge rows can also become NA via lag()). Confirm intended.
  mutate(chunkIndicator = cumsum(speaker != lag(speaker,2,default = speaker[1]))) %>%
  mutate(chunkIndicator = ifelse(chunkIndicator == (lead(chunkIndicator) - 1), lead(chunkIndicator),chunkIndicator)) %>%
  mutate(chunkIndicator = ifelse(chunkIndicator == (lag(chunkIndicator) + 1) & (chunkIndicator == lead(chunkIndicator) - 1), 
                                 lag(chunkIndicator),chunkIndicator)) %>%
  # n = number of utterances in each (docID, chunk).
  group_by(docID,chunkIndicator) %>%
  mutate(n = n()) %>%
  ungroup() %>%
  # Per-utterance length; not carried past the summarise() below.
  mutate(nchars = nchar(textclean)) %>%
  # ggplot(aes(x = nchars)) + 
  # geom_histogram()
  # filter(nchars < 5000) %>%
  # group_by(docID) %>%
  # mutate(delta2 = ind - lag(ind)) %>%
  # mutate(delta2 = ifelse(is.na(delta2),1,delta2)) %>%
  # mutate(chunkIndicator = cumsum(delta2 != 1)) %>%
  # group_by(docID,chunkIndicator) %>%
  # mutate(n = n()) %>%
  # arrange(desc(n))
  # Keep only chunks with at least three utterances.
  filter(n > 2) %>%
  # Prefix each utterance with its speaker (trailing period removed) and
  # collapse each chunk into one newline-separated transcript.
  mutate(textclean = paste0(gsub('\\.$','',speaker),': ',textclean)) %>%
  group_by(docID,chamber,date,chunkIndicator,n) %>%
  summarise(text = paste(textclean,collapse = '\n')) %>%
  ungroup() %>%
  # Label the sitting chair from the hearing date using Jan 1 cut-offs.
  # NOTE(review): verify these against the dates each chair actually took
  # office (believed to be early February of those years). Mislabelled
  # January hearings would be dropped by the name check below.
  mutate(fed = ifelse(date < as.Date('2006-01-01'),'Greenspan',
                      ifelse(date < as.Date('2014-01-01'),'Bernanke',
                             ifelse(date < as.Date('2018-01-01'),'Yellen','Powell')))) %>%
  mutate(nchars = nchar(text)) %>%
  # Keep only chunks whose text contains the labelled chair's name.
  # rowwise() is needed because grepl() uses only one pattern per call.
  rowwise() %>%
  filter(grepl(fed,text)) %>%
  ungroup()

# Histogram of chunk lengths (number of utterances per chunk).
ggplot(toSample, aes(x = n)) +
  geom_histogram()

# Draw random pairs of chunks to compare. For every chunk length
# nConv = 3..12 and every nchars-quartile bin within that length, 50 pairs
# are drawn: 10 x 4 x 50 = 2000 pairs in total. Each pair contains chunks
# from two different chairs (one chunk per chair is drawn, then two of those
# rows are drawn), in random order.
# Fixed seed so the sampled pairs are reproducible; changing the order of the
# random calls below would change which pairs are drawn.
set.seed(123)
chunks <- list()
counter <- 1
for(nConv in c(3:12)) {
  # Quartile cut points (min, 25%, 50%, 75%, max) of nchars for this length.
  qntls <- quantile(toSample %>%
             filter(n == nConv) %>% 
             pull(nchars))
  
  for(q in 2:length(qntls)) {
    # NOTE(review): both bounds are strict, so chunks whose nchars equals a
    # cut point (including the shortest and longest of each length) are
    # never sampled.
    tmp <- toSample %>%
      filter(n == nConv,
             nchars < qntls[q],
             nchars > qntls[q-1])
    
    # test <- NULL
    for(i in 1:50) {
      # NOTE(review): sample_n(size = 2) errors if this bin contains chunks
      # from fewer than two chairs.
      tmp2 <- tmp %>%
        group_by(fed) %>%
        sample_n(size = 1) %>%
        ungroup() %>%
        sample_n(size = 2) %>%
        slice(sample(1:2,2))
      
      # chunk: prompt text "Conversation 1:\n...\n\nConversation 2:\n...";
      # srcs:  the two sampled rows, in the same order as in the prompt.
      # (The outer paste() has no effect.)
      chunks[[counter]] <- list()
      chunks[[counter]]$chunk <- paste(paste0('Conversation ',1:2,':\n',
                                              tmp2 %>% pull(text),collapse = '\n\n'))
      chunks[[counter]]$srcs <- tmp2
      counter <- counter + 1
    }
  }
}

# Stack the sampled rows of every pair (index = pair number, rown = position
# in the prompt, 1 or 2). Then count each chair's appearances by position,
# presumably to check that chairs are balanced across Conversation 1 and 2.
# Build all pieces first and bind once; growing a data frame inside a loop
# is quadratic.
test <- bind_rows(lapply(seq_along(chunks), function(i) {
  chunks[[i]]$srcs %>%
    mutate(index = i,
           rown = row_number())
}))

test %>%
  count(fed, rown)

# Query the model for every sampled pair, four times each: asking which
# conversation is 'more' and which is 'less' aggressive, each with the two
# conversations in their original order and swapped. Each answer becomes one
# row of `res`, with the pair's metadata spread into *_1 / *_2 columns in the
# order the conversations appeared in the prompt.
max_tries <- 20  # attempts per prompt before giving up on the API
res <- NULL
for(i in seq_along(chunks)) {
  cat('----------------------------\n',i,'\n----------------------------\n')
  for(d in c('more','less')) {
    for(is_rev in c(TRUE,FALSE)) {
      if(is_rev) {
        # Swap the conversations: split at the second header, then relabel
        # the pieces so the text still reads "Conversation 1" / "Conversation 2".
        torev <- str_split(chunks[[i]]$chunk,pattern = '(\n\nConversation 2:)')[[1]]
        chnk <- paste(paste0('Conversation 1:',torev[2]),
                      gsub('Conversation 1','Conversation 2',torev[1]),sep = '\n\n')
        srcs <- chunks[[i]]$srcs %>% slice(2,1)
      } else {
        chnk <- chunks[[i]]$chunk
        srcs <- chunks[[i]]$srcs
      }

      prompts <- create_prompt(chunk = chnk,direction = d)
      # Skip prompts whose user message exceeds 10,000 characters.
      if(nchar(prompts[[2]]$content) > 10000) { next }

      Sys.sleep(2)
      openai_completions <- try(submit_openai(prompt = prompts,temperature = 0,n = 1))
      tries <- 1
      # Retry transient failures, but stop after max_tries instead of looping
      # forever (e.g. on an invalid key). Rows collected so far stay in `res`.
      while(inherits(openai_completions,'try-error')) {
        if(tries >= max_tries) {
          stop('OpenAI request failed ', tries, ' times (pair ', i,
               ', direction ', d, ', reversed ', is_rev, '): ',
               conditionMessage(attr(openai_completions, 'condition')),
               call. = FALSE)
        }
        Sys.sleep(5)
        openai_completions <- try(submit_openai(prompt = prompts,temperature = 0,n = 1))
        tries <- tries + 1
      }
      # cln <- try(submit_openai(prompt = clean_result(result = openai_completions$choices$message.content,direction = d),
      #                          temperature = 0,n = 1))
      # while(class(cln) == 'try-error') {
      #   Sys.sleep(5)
      #   cln <- try(submit_openai(prompt = clean_result(result = openai_completions$choices$message.content),temperature = 0,n = 1))
      # }

      res <- res %>%
        bind_rows(srcs %>%
                    select(-docID,-chunkIndicator) %>%
                    mutate(id = row_number()) %>%
                    pivot_wider(names_from = id,values_from = c('date','n','text','fed','nchars','chamber')) %>%
                    mutate(#choice = cln$choices$message.content,
                           explanation = openai_completions$choices$message.content,
                           direction = d,
                           reversed = is_rev))
    }
  }
}

save(res, file = './output/chatGPT_polite_BTM_aggression.RData')


# Tally the first digit (or "neither") found in each free-text answer, split
# by whether the pair was presented in swapped order. This is a rough
# heuristic: it takes whichever digit appears first in the explanation.
res %>%
  mutate(choice = str_extract(explanation, '\\d|Neither|neither'),
         choice = tolower(choice)) %>%
  count(choice, reversed)

# Print one sampled pair for manual inspection.
chunks[[101]]$chunk %>%
  cat()


# Alternative, simpler coding: ask the model to rate a single conversation's conflict on a 1-10 scale

create_prompt <- function(chunk) {
  # Redefine create_prompt() for single-conversation conflict scoring.
  # Returns a system message defining a "conflictual" conversation, plus a
  # user message asking for a 1-10 rating and then the conversation `chunk`.
  instructions <- 'Please read the following conversation. On a scale of 1 to 10, with 1 being the least conflictual and 10 being the most conflictual, how conflictual is this conversation? Only return the number.\n\n'
  list(
    list("role" = "system",
         "content" = "A conflictual conversation is one where the speakers disagree with each other."),
    list("role" = "user",
         "content" = stringr::str_c(instructions, chunk))
  )
}

submit_openai <- function(prompt, temperature = 0.2, n = 1) {
  # Same behaviour as the earlier definition: one gpt-3.5-turbo chat request
  # with `prompt` as the messages, followed by a one-second pause.
  out <- openai::create_chat_completion(
    model = "gpt-3.5-turbo", messages = prompt,
    temperature = temperature, n = n
  )
  Sys.sleep(1)
  out
}

# Try the conflict-score prompt on the first Yellen-era chunk whose text
# contains "rude".
example_chunk <- toSample %>%
  filter(grepl('rude',text),
         fed == 'Yellen') %>%
  slice(1) %>%
  pull(text)
# Without a match, the prompt would be sent with no conversation in it.
if (length(example_chunk) == 0) {
  stop("No Yellen-era chunk containing 'rude' was found in toSample.",
       call. = FALSE)
}

prompt <- create_prompt(chunk = example_chunk)
cat(prompt[[2]]$content)

completion <- submit_openai(prompt,temperature = 0,n = 1)
